import sys
sys.path.append('..')
from common.util import eval_perplexity

from common.optimizer import SGD
from common.trainer import RnnlmTrainer 

from ch06.rnnlm import Rnnlm
from dataset import ptb


# ---- Hyperparameters -------------------------------------------------
batch_size = 20
wordvec_size = 100
hidden_size = 100
time_size = 35   # forwarded to trainer.fit; presumably the truncated-BPTT unroll length -- TODO confirm
lr = 20.0        # learning rate handed to SGD
max_epoch = 4
max_grad = 0.25  # forwarded to trainer.fit; presumably the gradient-clipping threshold -- TODO confirm

# ---- Data ------------------------------------------------------------
corpus, word_to_id, id_to_word = ptb.load_data('train')
corpus_test = ptb.load_data('test')[0]  # only the corpus part of the test split is needed
vocab_size = len(word_to_id)

# Inputs are every token but the last; targets are the same sequence
# shifted one position ahead (next-word prediction).
xs, ts = corpus[:-1], corpus[1:]

# ---- Model, optimizer and trainer ------------------------------------
model = Rnnlm(vocab_size, wordvec_size, hidden_size)
optimizer = SGD(lr)
trainer = RnnlmTrainer(model, optimizer)

# ---- Training --------------------------------------------------------
trainer.fit(
    xs, ts, max_epoch, batch_size, time_size, max_grad,
    eval_interval=20,
)
trainer.plot(ylim=(0, 500))

# ---- Evaluation on the test data -------------------------------------
# Reset the model's state before evaluating on the test corpus.
model.reset_state()
ppl_test = eval_perplexity(model, corpus_test)
print('test perplexity: ', ppl_test)

# ---- Save the parameters ---------------------------------------------
model.save_params()

